Maximum level sum of a binary tree [DFS, BFS]

Time: O(N); Space: O(H) with DFS or O(W) with BFS (H = tree height, W = maximum level width); medium

You are given the root of a binary tree. The root is at level 1, its children are at level 2, and so on.

Return the smallest level X such that the sum of all the values of nodes at level X is maximal.
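The tie-breaking rule ("smallest level") is easy to get wrong, so here is a minimal sketch of just the selection step. It assumes the per-level sums are already collected in a list ordered from level 1 downward. Python's built-in `max` returns the first maximal item it encounters, so taking `max` over level indices yields the shallowest level among ties. The helper name `smallest_max_level` is hypothetical.

```python
def smallest_max_level(level_sums):
    """Return the 1-indexed level whose sum is maximal, preferring the smallest.

    Hypothetical helper: assumes level_sums[i] is the sum of level i + 1.
    """
    # max() returns the first index with the largest key, so ties resolve
    # to the shallowest (smallest) level.
    best_index = max(range(len(level_sums)), key=lambda i: level_sums[i])
    return best_index + 1


print(smallest_max_level([1, 7, -1]))  # 2
print(smallest_max_level([3, 5, 5]))   # 2: levels 2 and 3 tie, level 2 wins
```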

Example 1:

Input: [1, 7, 0, 7, -8, null, null]

Output: 2

Explanation:

  • Level 1 sum = 1.

  • Level 2 sum = 7 + 0 = 7.

  • Level 3 sum = 7 + (-8) = -1.

  • Level 2 has the maximum sum, so we return 2.
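The input appears to use the common level-order encoding, where `null` marks a missing child. The sketch below assumes that encoding (with Python's `None` in place of `null`). It rebuilds the example tree and recomputes the per-level sums from the explanation. `build_tree` and `level_sums` are hypothetical helpers, and the `TreeNode` class mirrors the one defined in the notebook.

```python
from collections import deque


class TreeNode:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None


def build_tree(values):
    """Build a binary tree from a level-order list; None marks a missing child.

    Hypothetical helper: assumes children are listed only for nodes that exist.
    """
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value (if present and not None) becomes the left child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that becomes the right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def level_sums(root):
    """Return [sum of level 1, sum of level 2, ...] via a level-by-level BFS."""
    sums = []
    level = [root] if root else []
    while level:
        sums.append(sum(node.val for node in level))
        level = [child for node in level for child in (node.left, node.right) if child]
    return sums


print(level_sums(build_tree([1, 7, 0, 7, -8, None, None])))  # [1, 7, -1]
```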

Notes:

  • The number of nodes in the given tree is between 1 and 10^4.

  • -10^5 <= node.val <= 10^5
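Python integers never overflow, but these constraints also bound every level sum. That matters if the solution is ported to a language with fixed-width integers. A level holds at most all $n$ nodes, so:

```latex
\left| \sum_{v \in \text{level } X} v.\mathrm{val} \right|
  \le n \cdot \max_v |v.\mathrm{val}|
  \le 10^4 \cdot 10^5 = 10^9 < 2^{31} - 1 \approx 2.147 \times 10^9
```

A signed 32-bit accumulator is therefore sufficient.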

Hints:

  1. Calculate the sum for each level then find the level with the maximum sum.

  2. How can you traverse the tree?

  3. How can you sum up the values for every level?

  4. Use DFS or BFS to traverse the tree keeping the level of each node, and sum up those values with a map or a frequency array.
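Hint 4 mentions accumulating sums in a map. Below is a sketch of that variant. It assumes a dict keyed by 1-indexed level and an iterative DFS that carries each node's level on an explicit stack. The function name `max_level_sum_with_map` is hypothetical, and the sketch assumes a non-empty tree, which the constraints guarantee.

```python
from collections import defaultdict


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_level_sum_with_map(root):
    """Sum node values per level in a dict, then pick the smallest best level.

    Assumes root is not None (the constraints guarantee at least one node).
    """
    sums = defaultdict(int)  # level (1-indexed) -> running sum
    stack = [(root, 1)]
    while stack:
        node, level = stack.pop()
        sums[level] += node.val
        # Visiting order does not matter: each node only adds to its own level.
        for child in (node.left, node.right):
            if child:
                stack.append((child, level + 1))
    best = max(sums.values())
    # Several levels may share the best sum; the problem asks for the smallest.
    return min(level for level, total in sums.items() if total == best)


root = TreeNode(1, TreeNode(7, TreeNode(7), TreeNode(-8)), TreeNode(0))
print(max_level_sum_with_map(root))  # 2
```

A dict avoids sizing anything in advance. A plain list indexed by depth (the "frequency array" from the hint) works just as well, because every level from the root down to the deepest node contains at least one node.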

[1]:
# Definition for a binary tree node.
class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None
[2]:
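import collections

class Solution1(object):
    """
    DFS solution (iterative, with an explicit stack)
    Time:  O(n)
    Space: O(h)
    """
    def maxLevelSum(self, root):
        """
        :type root: TreeNode
        :rtype: int
        """
        # level_sums[i] accumulates the values of all nodes at depth i
        # (0-indexed, so depth i is level i + 1).
        level_sums = []
        # An explicit stack replaces recursion: a skewed tree with 10^4 nodes
        # would exceed the default recursion limit (1000 in CPython).
        stack = [(root, 0)]
        while stack:
            node, i = stack.pop()
            if not node:
                continue
            # Pre-order visiting guarantees depths 0..i-1 already have slots.
            if i == len(level_sums):
                level_sums.append(0)
            level_sums[i] += node.val
            # Push right first so the left subtree is explored first.
            stack.append((node.right, i+1))
            stack.append((node.left, i+1))
        # index() returns the first (smallest) depth holding the maximum sum.
        return level_sums.index(max(level_sums)) + 1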
[3]:
s = Solution1()
root = TreeNode(1)
root.left = TreeNode(7)
root.right = TreeNode(0)
root.left.left = TreeNode(7)
root.left.right = TreeNode(-8)
assert s.maxLevelSum(root) == 2
[4]:
class Solution2(object):
    """
    BFS solution
    Time:  O(n)
    Space: O(w)
    """
    def maxLevelSum(self, root):
        """
        :type root: TreeNode
        :rtype: int
        """
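        # result: best level so far; level: current 1-indexed level.
        result, level, max_total = 0, 1, float("-inf")
        q = collections.deque([root])
        while q:
            total = 0
            # At this point the queue holds exactly the nodes of the current level.
            for _ in range(len(q)):
                node = q.popleft()
                total += node.val
                if node.left:
                    q.append(node.left)
                if node.right:
                    q.append(node.right)
            # Strict '>' keeps the earliest (smallest) level on ties.
            if total > max_total:
                result, max_total = level, total
            level += 1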
        return result
[5]:
s = Solution2()
root = TreeNode(1)
root.left = TreeNode(7)
root.right = TreeNode(0)
root.left.left = TreeNode(7)
root.left.right = TreeNode(-8)
assert s.maxLevelSum(root) == 2